The University's management requires an automated pipeline that builds a classifier able to determine a plant's species from a photo.
# imports
import os
import random
import warnings
from time import time
from math import floor
from pathlib import Path
import pandas as pd, numpy as np
from pprint import pprint
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
from collections import defaultdict
import tensorflow as tf
warnings.filterwarnings('ignore')
%matplotlib inline
# reproducibility: seed the global RNGs of Python, NumPy and TensorFlow so
# random sampling and (later) model initialisation repeat run-to-run.
seed = 7
random.seed(seed)
np.random.seed(seed)
tf.random.set_seed(seed)
data_dir = Path('./data/train/')
# One sub-directory per species. Sorted because os.listdir/iterdir order is
# arbitrary and the class order must be stable for reproducible results;
# non-directories and hidden entries (e.g. a stray .DS_Store) are skipped.
SPECIES = sorted(
    p.name for p in data_dir.iterdir()
    if p.is_dir() and not p.name.startswith('.')
)
NUM_SPECIES = len(SPECIES)

# number of images of each species
species_counts = {name: len(os.listdir(data_dir / name)) for name in SPECIES}
for name, count in species_counts.items():
    print(f'{name:<25} : {count}')
print(f'\nTotal: {sum(species_counts.values())} images across {NUM_SPECIES} species')
Black-grass : 263 Charlock : 390 Cleavers : 287 Common Chickweed : 611 Common wheat : 221 Fat Hen : 475 Loose Silky-bent : 671 Maize : 221 Scentless Mayweed : 516 Shepherds Purse : 231 Small-flowered Cranesbill : 496 Sugar beet : 385 Total: 12
import skimage
from skimage.io import imread, imshow
def load_dataset(data_dir=data_dir):
    '''Read every training image from disk, grouped by species.

    Expects one sub-directory of ``data_dir`` for each name in ``SPECIES``.
    Files inside each species directory are read in sorted order, because
    ``os.listdir`` order is arbitrary. Sorting keeps the resulting lists, and
    any seeded random sampling from them, reproducible across machines.

    Parameters
    ----------
    data_dir : str or pathlib.Path
        Root directory containing the per-species sub-directories.

    Returns
    -------
    data : dict
        ``{'images': [...], 'labels': [...]}`` -- parallel lists holding each
        image array (as returned by ``imread``) and its species name.
    species_data : dict
        ``{species_name: [image, ...]}``; only species that yielded at least
        one image appear as keys.
    '''
    data_dir = Path(data_dir)  # also accept plain string paths
    data = {'images': [], 'labels': []}
    species_data = defaultdict(list)
    for species in SPECIES:
        species_dir = data_dir / species
        # pad the progress-bar label to 25 chars so the bars line up
        for image_name in tqdm(sorted(os.listdir(species_dir)),
                               desc=f'{species:<25}', ncols=90):
            img = imread(species_dir / image_name)
            data['images'].append(img)
            data['labels'].append(species)
            species_data[species].append(img)
    return data, dict(species_data)
# Load the whole training set once:
#   data    -> {'images': [...], 'labels': [...]} as parallel lists
#   species -> {species_name: [image, ...]}
data, species = load_dataset(data_dir=data_dir)
Black-grass : 100%|████████████████████████| 263/263 [00:05<00:00, 46.52it/s] Charlock : 100%|███████████████████████| 390/390 [00:03<00:00, 114.64it/s] Cleavers : 100%|███████████████████████| 287/287 [00:01<00:00, 221.79it/s] Common Chickweed : 100%|███████████████████████| 611/611 [00:02<00:00, 298.89it/s] Common wheat : 100%|███████████████████████| 221/221 [00:02<00:00, 109.33it/s] Fat Hen : 100%|███████████████████████| 475/475 [00:02<00:00, 203.91it/s] Loose Silky-bent : 100%|████████████████████████| 671/671 [00:07<00:00, 94.28it/s] Maize : 100%|████████████████████████| 221/221 [00:02<00:00, 78.54it/s] Scentless Mayweed : 100%|███████████████████████| 516/516 [00:01<00:00, 293.85it/s] Shepherds Purse : 100%|███████████████████████| 231/231 [00:01<00:00, 201.39it/s] Small-flowered Cranesbill: 100%|███████████████████████| 496/496 [00:02<00:00, 183.50it/s] Sugar beet : 100%|████████████████████████| 385/385 [00:05<00:00, 69.35it/s]
# Sanity check: total number of images loaded (4767 for this training set).
n_total = len(data["images"])
n_total
4767
# !pip install scikit-image
# !pip install plantcv
from skimage.color import rgb2gray, rgb2hsv, gray2rgb
from skimage.filters import sobel, threshold_otsu
from skimage.feature import canny
from skimage.measure import find_contours
from skimage.morphology import binary_dilation, dilation
from plantcv import plantcv as pcv
def display(img, label, fontsize=18, cmap=None):
    '''Show a single image with a title, then print its array shape.

    Parameters
    ----------
    img : array-like with a ``.shape`` attribute
        Image drawn with ``plt.imshow``.
    label : object
        Title text (formatted via an f-string).
    fontsize : int
        Font size of the title.
    cmap : str or matplotlib Colormap, optional
        Colormap forwarded to ``imshow``; None keeps matplotlib's default.
    '''
    plt.imshow(img, cmap=cmap)  # cmap=None is imshow's own default
    plt.axis('off')
    plt.title(f'{label}', fontsize=fontsize)
    plt.show()
    print(f'Shape: {img.shape}')
ind = 400  # arbitrary sample index to preview
# show that sample together with its species label
display(label=data['labels'][ind], img=data['images'][ind])
Shape: (530, 530, 3)
def _sample_images(data, num):
    '''Draw `num` random images (with replacement) per species, species-major order.'''
    return [random.choice(imgs) for imgs in data.values() for _ in range(num)]


def _plant_mask(img):
    '''Binary mask from the LAB 'a' channel: threshold at 120 ('dark'), then fill objects smaller than 85.'''
    mask = pcv.rgb2gray_lab(img, 'a')
    mask = pcv.threshold.binary(mask, 120, 255, 'dark')  # threshold
    return pcv.fill(mask, 85)  # fill noise, small objects


def _apply_filter(img, img_type, func=None):
    '''Return ``(image_to_show, cmap)`` after applying `func` or the `img_type` filter.'''
    if func is not None:
        return func(img), 'gray'
    if img_type == 'hsv':
        return rgb2hsv(img), None  # hsv color space
    if img_type == 'grayscale':
        # single-channel result: without a gray cmap matplotlib renders it
        # with its default (colour) colormap
        return rgb2gray(img), 'gray'
    if img_type == 'lab':
        return pcv.rgb2gray_lab(img, 'a'), None  # 'a' channel of LAB space
    if img_type == 'sobel':
        return sobel(img), None  # sobel filter
    if img_type == 'binarize':
        return _plant_mask(img), 'gray'
    if img_type == 'edges':
        # edges of the binary plant mask via the Canny detector
        return canny(_plant_mask(img), sigma=0.88), 'gray'
    return img, None  # 'original' or any unrecognised img_type


def plot_images(data=species, num=6, img_type='original', fontsize=18,
                func=None, transpose=False, images=None, figsize=(22, 46)):
    '''Plot a grid of sample images: one row (or column) per species.

    Parameters
    ----------
    data : dict
        Mapping ``species_name -> list of images``; defaults to the
        ``species`` dict loaded earlier. Its key order decides which species
        labels each row/column.
    num : int
        Number of images shown per species.
    img_type : str
        Filter applied when ``func`` is None. One of 'original', 'hsv',
        'grayscale', 'lab', 'sobel', 'binarize', 'edges'. Unknown values show
        the image unchanged.
    fontsize : int
        Font size of the subplot titles.
    func : callable, optional
        Custom ``img -> img`` transform. It takes precedence over
        ``img_type`` and its result is drawn with a gray colormap.
    transpose : bool
        If True, each species occupies a column instead of a row.
    images : list, optional
        Pre-selected images ordered species-major: ``num`` per species, in
        ``data`` key order. When None, ``num`` images are drawn at random
        (with replacement) from every species in ``data``.
    figsize : tuple
        Figure size forwarded to ``plt.subplots``.
    '''
    keys = list(data.keys())
    if images is None:
        images = _sample_images(data, num)

    if transpose:
        fig, axes = plt.subplots(num, len(keys), figsize=figsize, squeeze=False)
        # column-major traversal so image n lands in column n // num,
        # i.e. each column holds exactly one species
        axes = axes.flatten(order='F')
    else:
        fig, axes = plt.subplots(len(keys), num, figsize=figsize, squeeze=False)
        axes = axes.flatten()

    for n, (ax, img) in enumerate(zip(axes, images)):
        ax.set_title(f'{keys[n // num]}', fontsize=fontsize)
        shown, cmap = _apply_filter(img, img_type, func)
        ax.imshow(shown, cmap=cmap)
        ax.axis('off')
    fig.tight_layout()
# sample grids: the raw photos first, then the LAB 'a' channel view
for view in ('original', 'lab'):
    plot_images(img_type=view)